import numpy as np
import pandas as pd
import copy
import scipy
import scipy.stats
try:
    from src.interpolation.splines import eval_linear
except:
    from ..interpolation.splines import eval_linear

from src.models_new import surplus
from src.models_new import prices
from src.models_new import model_objects


def get_targeting_shares(state, params, surp_array_by_tau, shares_guess, mri_grid,
                         entry_prob_by_tau=None,
                         use_absolute_advantage=False,
                         well_target=True,
                         tol=1e-9,
                         max_iter=2000):
    """ Compute the share of well draws which in equilibrium target
    each type of rig.

    The shares are found by fixed-point iteration on
    ``model_objects.update_shares``. The first 101 iterations take the full
    update; after that the update is damped (simple average of the old and
    new shares) to help convergence.

    Args:
        state: the current (simple) market state, passed through to
            ``model_objects.update_shares``.
        params (dict): model parameters. Reads 'delta' and the contract-length
            weights 'p_2', 'p_3', 'p_4'.
        surp_array_by_tau: surplus arrays by contract length (passed through).
        shares_guess (np.array): initial guess [share_low, share_mid, share_high].
        mri_grid (np.array): well-complexity grid (passed through).
        entry_prob_by_tau: optional entry probabilities by contract length
            (passed through).
        use_absolute_advantage (bool): passed through to ``update_shares``.
        well_target (bool): passed through to ``update_shares``.
        tol (float): convergence tolerance on the maximum absolute change in
            the shares between iterations.
        max_iter (int): maximum number of fixed-point iterations (>= 1).

    Returns:
        shares: an array that gives the share of wells that enter and
            target a particular type of rig. This is in the form of
            an array:

                [share_low, share_mid, share_high]

            In addition, 1 - share_low - share_mid - share_high is
            the share of wells that do not enter.
        entry_prob_by_tau_new: the entry probabilities returned by the last
            call to ``model_objects.update_shares``.

    Raises:
        ValueError: if ``max_iter`` is smaller than 1.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    shares = shares_guess

    for k in range(max_iter):
        output_dict, _, _, entry_prob_by_tau_new = model_objects.update_shares(
            mri_grid, params['delta'], state, shares, params,
            surp_array_by_tau, verbose=True, entry_prob_by_tau=entry_prob_by_tau,
            use_absolute_advantage=use_absolute_advantage, well_target=well_target)

        # Aggregate the per-(spec, tau) targeting probabilities using the
        # contract-length weights p_tau.
        s_1 = np.array([
            sum(params[f'p_{tau}'] * output_dict[(spec, tau)] for tau in [2, 3, 4])
            for spec in ['low', 'mid', 'high']
        ])

        if np.max(np.abs(s_1 - shares)) < tol:
            break
        if k <= 100:
            shares = s_1
        else:
            # Damped update once the plain iteration has had time to settle.
            shares = 0.5 * shares + 0.5 * s_1
    else:
        print("Error: Target PDF failed to converge")
        print("FAIL AT", params)
        print("DIFF", s_1, shares)

    return shares, entry_prob_by_tau_new


def get_moments(state_detail, mri_state_grid, n_rigs):
    """ Compute the per-period simulated moments from the detailed state.

    Args:
        state_detail (dict): keyed by rig spec ('low', 'mid', 'high'); each
            value is a list of three 2-D arrays for contract lengths
            tau = 2, 3, 4 whose rows line up with ``mri_state_grid``. The last
            column of each array is read as the mass of contracts that are
            new or extended in the current period.
        mri_state_grid (np.array): well-complexity grid.
        n_rigs (dict): number of rigs by spec.

    Returns:
        dict with, for each spec, 'mri_mean_*', 'mri_variance_*',
        'utilization_*', 'n_matches_*' and 'p_new_match_{tau}_*'.
    """
    specs = ('low', 'mid', 'high')
    moments_sim = {}

    for spec in specs:
        arrays = state_detail[spec]
        # Mass of new/extended contracts at each complexity, summed over tau.
        new_mass = arrays[0][:, -1] + arrays[1][:, -1] + arrays[2][:, -1]
        mass_total = new_mass.sum()
        mri_mean = (new_mass * mri_state_grid).sum() / mass_total

        moments_sim[f'mri_mean_{spec}'] = mri_mean
        moments_sim[f'mri_variance_{spec}'] = (
            (new_mass * (mri_state_grid - mri_mean) ** 2).sum() / mass_total
        )
        # Total occupied mass across all contract lengths and remaining periods.
        moments_sim[f'utilization_{spec}'] = (
            arrays[0].sum().sum() + arrays[1].sum().sum() + arrays[2].sum().sum()
        )

    for spec in specs:
        # NOTE: n_matches includes extensions as well as brand-new matches.
        count_key = f'n_matches_{spec}'
        moments_sim[count_key] = 0
        for tau in (2, 3, 4):
            p_new = state_detail[spec][tau - 2][:, -1].sum()
            moments_sim[f'p_new_match_{tau}_{spec}'] = p_new
            moments_sim[count_key] += n_rigs[spec] * p_new

    return moments_sim


def get_aggregated_moments(moments, g_data, n_rigs):
    """ Aggregate the per-period simulated moments into the estimation targets.

    Args:
        moments (pd.DataFrame): one row per simulated period with the columns
            produced by ``get_moments``.
        g_data: mapping with 'g' (gas-price series; presumably aligned with the
            rows of ``moments`` - confirm with caller) and '2006' (a row
            selector passed to ``moments.loc``, presumably a boolean mask).
        n_rigs (dict): number of rigs by spec ('low', 'mid', 'high').

    Returns:
        dict with utilization moments by spec, contract-length shares 'p_2' and
        'p_3', boom/bust complexity means by spec, and 'mri_variance'.
    """
    specs = ['low', 'mid', 'high']

    # Get utilization moments
    moments_agg = dict()
    for spec in specs:
        utilization = moments[f'utilization_{spec}']
        moments_agg[f'utilization_covariance_{spec}'] = np.cov(utilization, g_data['g'])[0, 1]
        moments_agg[f'utilization_variance_{spec}'] = np.var(utilization)
        moments_agg[f'utilization_mean_{spec}'] = np.mean(utilization)
        moments_agg[f'util_2006_{spec}'] = np.mean(moments.loc[g_data['2006'], f'utilization_{spec}'])

    # Get contract length moments: new matches by contract length, summed
    # over rig specs (each spec weighted by its number of rigs).
    contract_total = {f'p_{tau}': 0 for tau in [2, 3, 4]}
    for tau in [2, 3, 4]:
        for spec in specs:
            contract_total[f'p_{tau}'] += np.sum(moments[f'p_new_match_{tau}_{spec}']) * n_rigs[spec]
    contract_total_denom = contract_total['p_2'] + contract_total['p_3'] + contract_total['p_4']
    moments_agg['p_2'] = contract_total['p_2'] / contract_total_denom
    moments_agg['p_3'] = contract_total['p_3'] / contract_total_denom

    # Get complexity moments (match-weighted means in boom vs. bust periods)
    mask_boom = (g_data['g'] >= g_data['g'].mean())
    for spec in specs:
        weighted_mri = moments[f'mri_mean_{spec}'] * moments[f'n_matches_{spec}']
        n_matches = moments[f'n_matches_{spec}']
        moments_agg[f'mri_mean_{spec}_boom'] = weighted_mri[mask_boom].sum() / n_matches[mask_boom].sum()
        moments_agg[f'mri_mean_{spec}_bust'] = weighted_mri[~mask_boom].sum() / n_matches[~mask_boom].sum()

    # Get variance of complexity (match-weighted average of per-period variances)
    total_matches = 0
    numerator_var = 0
    for spec in specs:
        total_matches += moments[f'n_matches_{spec}'].sum()
        numerator_var += (moments[f'n_matches_{spec}'] * moments[f'mri_variance_{spec}']).sum()
    moments_agg['mri_variance'] = numerator_var / total_matches

    return moments_agg


def get_extensions_moment(eta, df_contracts, surplus_grid, surplus_values_by_tau_spec):
    """ Get the model-predicted probability of extension using the
    empirical data on contracts.

    For every (tau, spec) cell, the match surplus is interpolated (with
    ``eval_linear``) at each observed contract's
    ('mri', 'g_end', 'n_l_end', 'n_m_end', 'n_h_end') values. A contract
    contributes ``eta`` when that surplus is positive and 0 otherwise. The
    moment is the mean contribution over ALL contracts of every contract
    length and rig spec.

    Args:
        eta (float): scale applied to the positive-surplus indicator.
        df_contracts (pd.DataFrame): contract data with columns 'spec', 'tau',
            'mri', 'g_end', 'n_l_end', 'n_m_end', 'n_h_end'. Not modified.
        surplus_grid: interpolation grid passed to ``eval_linear``.
        surplus_values_by_tau_spec (dict): surplus values keyed by (tau, spec).

    Returns:
        float: the mean predicted extension probability.
    """
    # Accumulate over every (tau, spec) cell. This list must live outside
    # both loops so that all contract lengths enter the mean.
    reneg_list = list()
    for tau in [2, 3, 4]:
        for spec in ['low', 'mid', 'high']:
            a = df_contracts.loc[(
                    (df_contracts['spec'] == spec)
                    & (df_contracts['tau'] == tau)
                ), ['mri', 'g_end', 'n_l_end', 'n_m_end', 'n_h_end']]
            a = np.ascontiguousarray(a)
            surp_spec = eval_linear(
                surplus_grid,
                surplus_values_by_tau_spec[(tau, spec)],
                a
            )
            reneg_list.append(eta * (surp_spec > 0))
    mean_reneg = np.concatenate(reneg_list).mean()
    return mean_reneg


def get_objective_function(moments_sim, moments_data, coefs_predicted,  coefs_data,
                           mean_reneg, mean_reneg_data, weights):
    """ Build the weighted sum-of-squared-deviations objective.

    Every term is ``weights[name] * (simulated - data) ** 2``. The terms are
    added in this order:
      1. every moment in ``moments_data.index``,
      2. every coefficient in ``coefs_data``,
      3. the mid-minus-low and high-minus-mid price gaps,
      4. the extension probability (weight key 'extension').

    Returns:
        (objective, deviations): the scalar objective and a dict of the raw
        (unsquared, unweighted) deviations. The extension deviation is not
        stored in ``deviations``.
    """
    deviations = dict()
    total = 0

    # Direct moment and auxiliary-coefficient deviations.
    groups = (
        (moments_data.index, moments_sim, moments_data),
        (coefs_data.keys(), coefs_predicted, coefs_data),
    )
    for names, simulated, observed in groups:
        for name in names:
            deviations[name] = simulated[name] - observed[name]
            total += deviations[name] ** 2 * weights[name]

    # Gaps between adjacent rig-spec prices, simulated vs. data.
    price_gaps = (
        ('diff_price_mid', 'price_mid', 'price_low'),
        ('diff_price_high', 'price_high', 'price_mid'),
    )
    for label, upper, lower in price_gaps:
        simulated_gap = coefs_predicted[upper] - coefs_predicted[lower]
        observed_gap = coefs_data[upper] - coefs_data[lower]
        deviations[label] = simulated_gap - observed_gap
        total += deviations[label] ** 2 * weights[label]

    total += (mean_reneg - mean_reneg_data) ** 2 * weights['extension']

    return total, deviations


def do_one_simulation(row, state_detail, params, config, well_target=True):
    """ Advance the simulated market by a single date.

    Steps, in order: match-surplus cutoffs, (optional) entry probabilities
    for this date, the surplus array, extensions, the simple state, the
    equilibrium targeting shares, and finally the detailed-state update.

    Args:
        row: one row of the state data. ``g``, ``n_l``, ``n_m`` and ``n_h`` are
            read as attributes; ``date`` is also read when
            ``config['entry_prob_by_tau_by_ym']`` is not None.
        state_detail (dict): detailed state by rig spec from the previous
            date. A deep copy of it is what gets handed to
            ``model_objects.update_state_all``.
        params (dict): model parameters ('delta' is read here).
        config (dict): simulation configuration assembled in ``do_simulation``.
        well_target (bool): forwarded to the targeting and state-update steps.

    Returns:
        dict with keys 'cutoff_min', 'cutoff_max', 'cutoff_mask_by_spec',
        'shares', 'state_simple', 'state_detail', 'state_all',
        'entry_prob_by_tau_new', 'prob_extension', 'prob_match_by_spec_tau'.
    """
    current_state = np.array([row.g, row.n_l, row.n_m, row.n_h])
    grid = config['mri_state_grid']
    interp_grid = config['surplus_grid']
    values_by_tau_spec = config['surplus_values_by_tau_spec']

    # Cutoffs on well complexity for each rig spec.
    lowest_cutoff, highest_cutoff, masks_by_spec = surplus.get_cutoffs(
        state=current_state,
        value_zero=config['value_zero'],
        surplus_grid=interp_grid,
        surplus_values_by_spec=values_by_tau_spec,
        mri_max=config['mri_max'],
        mri_state_grid=grid,
    )

    # Entry probabilities are optional and, when supplied, indexed by date.
    entry_table = config['entry_prob_by_tau_by_ym']
    entry_prob_by_tau = None if entry_table is None else entry_table[row.date]

    surplus_by_tau = surplus.get_surplus_array(
        mri=grid,
        state=current_state,
        surplus_grid=interp_grid,
        surplus_values_by_tau_spec=values_by_tau_spec,
        delta=params['delta'],
        well_outside_option_by_tau_spec=config['well_outside_option_by_tau_spec'],
    )

    extensions = model_objects.get_extensions(
        cutoff_mask_by_spec=masks_by_spec,
        params=params,
        state_detail=state_detail,
    )

    simple_state = model_objects.get_state_simple(
        extensions_by_spec_tau=extensions,
        state_detail=state_detail,
        n_rigs=config['n_rigs'],
        g=row.g,
    )

    absolute_adv = config['use_absolute_advantage']
    equilibrium_shares, entry_prob_by_tau_new = get_targeting_shares(
        state=simple_state,
        params=params,
        surp_array_by_tau=surplus_by_tau,
        shares_guess=config['shares_guess'],
        mri_grid=grid,
        entry_prob_by_tau=entry_prob_by_tau,
        use_absolute_advantage=absolute_adv,
        well_target=well_target,
    )

    updated = model_objects.update_state_all(
        cutoffs_mask=masks_by_spec,
        state_detail=copy.deepcopy(state_detail),
        state=simple_state,
        shares=equilibrium_shares,
        mri_state_grid=grid,
        params=params,
        surp_array_by_tau=surplus_by_tau,
        extensions_by_spec_tau=extensions,
        n_rigs=config['n_rigs'],
        entry_prob_by_tau=entry_prob_by_tau,
        verbose=config['verbose'],
        use_absolute_advantage=absolute_adv,
        well_target=well_target,
    )

    return dict(
        cutoff_min=lowest_cutoff,
        cutoff_max=highest_cutoff,
        cutoff_mask_by_spec=masks_by_spec,
        shares=equilibrium_shares,
        state_simple=simple_state,
        state_detail=updated['state_detail'],
        state_all=updated,
        entry_prob_by_tau_new=entry_prob_by_tau_new,
        prob_extension=extensions,
        prob_match_by_spec_tau=updated['prob_match_by_spec_tau'],
    )


def do_simulation(state_data, params, surplus_grid, surplus_values_by_tau_spec,
                  n_rigs, mri_max, value_zero=False, burn_iter_max=5000,
                  shares_guess=None, entry_prob_by_tau_by_ym=None,
                  verbose=False, well_outside_option_by_tau_spec=None,
                  use_absolute_advantage=False, tau_multiplier=1, well_target=True):
    """ Run the entire simulation for a fixed set of parameters
    (given in params).

    The detailed state starts empty and is "burned in" by repeatedly
    simulating the 2000-01-01 row. Burn-in stops when the simple state moves
    by less than 1e-5 (max absolute change) between successive iterations,
    or after ``burn_iter_max`` iterations. The burned-in state is then
    simulated forward through every row of ``state_data``.

    Args:
        state_data (pd.DataFrame): data for the exact state used (must contain
            a row with date '2000-01-01' and columns 'date', 'g', 'n_l',
            'n_m', 'n_h')
        params (dict): names parameters to use in the model
        surplus_grid (UCgrid): grid for surplus interpolation
        surplus_values_by_tau_spec (dict): values for computation of
            the match surplus (where the keys are (tau, spec)).
        n_rigs (dict): number of rigs in the market, keys are [low, mid, high]
        mri_max (float): upper bound of well complexity in the model
        value_zero (boolean): whether or not the dynamic value is zero
        burn_iter_max (int): maximum iterations to burn in (i.e. if code
            has not fully converged by this point then move on to simulation)
        shares_guess (np.array or None): initial guess of the targeting shares
            (each element corresponds to the guess of a rig type in
            [low, mid, high]). None means [1/3, 1/3, 1/3].
        entry_prob_by_tau_by_ym (dict): nested dictionary of the entry prob.
            of each contract length by the date (month)
        verbose (boolean): whether to surface lots of results
        well_outside_option_by_tau_spec (dict): outside option of the well (note
            that this may be defunct in the latest version)
        use_absolute_advantage (boolean): whether to work with absolute
            advantage (note that this may be defunct in the latest version)
        tau_multiplier (float): what to multiply the tau length by (this is
            mainly used when testing alternative time periods e.g. fortnightly
            when 2 month contracts -> 4 period fortnightly contracts)
        well_target (boolean): whether the wells target rigs or the rigs target
            wells. Default is that wells target rigs (this functionality
            is used as a robustness check).

    Returns:
        output (dict)

    """
    if shares_guess is None:
        # A fresh array per call, so no mutable default is shared between calls.
        shares_guess = np.array([1/3, 1/3, 1/3])

    cutoffs_min_by_state = list()
    cutoffs_max_by_state = list()
    shares_by_state = list()
    moments_by_state = list()

    state_detail_by_ym = dict()
    state_simple_by_ym = dict()
    entry_prob_by_tau_by_ym_new = dict()
    prob_new_match_by_ym = dict()
    prob_extension_by_ym = dict()
    share_available_by_ym = dict()

    # INITIALIZE ------------------------------------------------------------------------
    mri_state_grid_length = 30
    mri_state_grid = np.linspace(0, mri_max, mri_state_grid_length)

    state_detail_2 = np.zeros((mri_state_grid_length, 2 * tau_multiplier))
    state_detail_3 = np.zeros((mri_state_grid_length, 3 * tau_multiplier))
    state_detail_4 = np.zeros((mri_state_grid_length, 4 * tau_multiplier))

    state_detail = {
        'low': copy.deepcopy([state_detail_2, state_detail_3, state_detail_4]),
        'mid': copy.deepcopy([state_detail_2, state_detail_3, state_detail_4]),
        'high': copy.deepcopy([state_detail_2, state_detail_3, state_detail_4])
    }

    state_surplus = (
        state_data
        .loc[state_data['date'] == '2000-01-01', ['g', 'n_l', 'n_m', 'n_h']]
        .values[0]
    )
    state_simple_old = np.array(state_surplus, dtype=float)

    config = {
        'value_zero': value_zero,
        'surplus_grid': surplus_grid,
        'surplus_values_by_tau_spec': surplus_values_by_tau_spec,
        'mri_max': mri_max,
        'mri_state_grid': mri_state_grid,
        'entry_prob_by_tau_by_ym': entry_prob_by_tau_by_ym,
        'n_rigs': n_rigs,
        'shares_guess': shares_guess,
        'verbose': verbose,
        'well_outside_option_by_tau_spec': well_outside_option_by_tau_spec,
        'use_absolute_advantage': use_absolute_advantage,
        'tau_multiplier': tau_multiplier
    }

    # DO THE BURN IN --------------------------------------------------------------------
    row = (
        state_data
        .loc[state_data['date'] == '2000-01-01', ['date', 'g', 'n_l', 'n_m', 'n_h']]
        .iloc[0]
    )

    i = 0
    while i < burn_iter_max:
        one_sim_output = do_one_simulation(row, state_detail, params, config, well_target)
        state_simple_new = np.array(one_sim_output['state_simple'], dtype=float)
        state_detail = copy.deepcopy(one_sim_output['state_detail'])

        # Converged once the simple state stops moving between successive
        # burn-in iterations.
        if np.max(np.abs(state_simple_new - state_simple_old)) < 0.00001:
            break

        state_simple_old = state_simple_new
        i += 1

    # DO THE SIMULATION -----------------------------------------------------------------
    for row in state_data.itertuples():
        one_sim_output = do_one_simulation(
            row=row,
            state_detail=state_detail,
            params=params,
            config=config,
            well_target=well_target
        )

        moments = get_moments(
            one_sim_output['state_detail'],
            config['mri_state_grid'],
            config['n_rigs']
        )

        # Update state detail
        state_detail = copy.deepcopy(one_sim_output['state_detail'])

        # Save all results
        cutoffs_min_by_state.append(one_sim_output['cutoff_min'])
        cutoffs_max_by_state.append(one_sim_output['cutoff_max'])
        shares_by_state.append(one_sim_output['shares'])
        moments_by_state.append(moments)

        entry_prob_by_tau_by_ym_new[row.date] = one_sim_output['entry_prob_by_tau_new']
        state_detail_by_ym[row.date] = one_sim_output['state_detail']
        state_simple_by_ym[row.date] = one_sim_output['state_simple']
        # NOTE(review): the trailing commas wrap these three values in 1-tuples.
        # Kept as-is because downstream code may index ``[0]``; confirm before
        # removing them.
        prob_new_match_by_ym[row.date] = one_sim_output['prob_match_by_spec_tau'],
        prob_extension_by_ym[row.date] = one_sim_output['prob_extension'],
        share_available_by_ym[row.date] = one_sim_output['shares'],

    output = {
        'moments_by_state': moments_by_state,
        'cutoffs_min_by_state': cutoffs_min_by_state,
        'cutoffs_max_by_state': cutoffs_max_by_state,
        'state_detail_by_ym': state_detail_by_ym,
        'shares_by_state': shares_by_state,
        'state_simple_by_ym': state_simple_by_ym,
        'entry_prob_by_tau_by_ym': entry_prob_by_tau_by_ym_new,
        'prob_new_match_by_ym': prob_new_match_by_ym,
        'prob_extension_by_ym': prob_extension_by_ym,
        'share_available_by_ym': share_available_by_ym
    }

    return output


#%% PUT IT ALL TOGETHER WITH RUNNING STEPS ----------------------------------------------
def run_steps(x, args_by_names):
    """ Run all of the steps to simulate the model. Use with the optimization code.

    Args:
        x (list): a guess of the parameters for optimization. Zipped with
            args_by_names['x_names'] to build a dictionary of parameters.
        args_by_names (dict): inputs to the optimization. Keys read here:
            'x_names': parameter names matching the entries of ``x``,
            'params_fixed': parameters held fixed (these override entries
                of ``x`` with the same name),
            'mri_max': upper bound of well complexity,
            'match_values_by_tau_spec': input to ``surplus.build_fast_surplus``,
            'non_myopic_dict': passed to ``surplus.build_fast_surplus`` and
                ``prices.get_fast_price_moments``,
            'surplus_grid': grid for surplus interpolation,
            'state_data': state data driving the simulation,
            'n_rigs': number of rigs by spec,
            'value_zero', 'entry_prob_by_tau_by_ym', 'tau_multiplier',
                'well_target': forwarded to ``do_simulation``,
            'verbose': forwarded to ``do_simulation`` and also turns on the
                extra printing below,
            'g_data': gas-price data used to aggregate the moments,
            'price_match_values_by_spec', 'values_by_spec': inputs to the
                price moments,
            'df_contracts': contract data used for the price moments and
                the extension moment,
            'moments_data', 'coefs_data', 'mean_reneg_data': data targets,
            'weights': weights of the objective function,
            'verbose_output': whether to return the full results dict.

    Returns:
        If args_by_names['verbose_output'] is falsy:
            output (float): objective function
        Otherwise:
            output (dict): dictionary with keys as follows.
               'objective': objective function,
               'cutoffs_min': dataframe of minimum values of the cutoffs,
               'cutoffs_max': dataframe of maximum values of the cutoffs,
               'moments_sim': simulated moments (except price moments and
                    extension moments),
               'coefs_predicted': simulated moments from the auxiliary
                    price regression.
               'mean_reneg': simulated probability of extension.
               'prices_new_by_spec': dictionary of predicted prices
                **output: every entry returned by ``do_simulation``

    """

    # Set up the parameters as a dictionary (fixed parameters take precedence)
    params = dict(zip(args_by_names['x_names'], x))
    params = {**params, **args_by_names['params_fixed']}

    # Pre-compute some important quantities
    params['p_4'] = 1.0 - params['p_3'] - params['p_2']
    params['a_0'] = np.array([1.0, 1.0, 1.0])
    params['a_1'] = np.array([params['a_1_low'], params['a_1_mid'], params['a_1_high']])
    # Probability mass of N(mu_0, sigma_0) on [0, mri_max].
    params['denom'] = (
        scipy.stats.norm.cdf(
            args_by_names['mri_max'],
            loc=params['mu_0'],
            scale=params['sigma_0']
        )
        - scipy.stats.norm.cdf(
            0,
            loc=params['mu_0'],
            scale=params['sigma_0']
        )
    )

    # Build the surplus values to interpolate
    (
        surplus_values_by_tau_spec,
        well_outside_option_by_tau_spec
    ) = surplus.build_fast_surplus(
        args_by_names['match_values_by_tau_spec'],
        params,
        args_by_names['surplus_grid'],
        args_by_names['non_myopic_dict']
    )

    # Do the simulation (supply side - i.e. the rig dynamics)
    output = do_simulation(
        state_data=args_by_names['state_data'],
        params=params,
        surplus_grid=args_by_names['surplus_grid'],
        surplus_values_by_tau_spec=surplus_values_by_tau_spec,
        n_rigs=args_by_names['n_rigs'],
        mri_max=args_by_names['mri_max'],
        value_zero=args_by_names['value_zero'],
        burn_iter_max=1000,
        shares_guess=np.array([1/3, 1/3, 1/3]),
        entry_prob_by_tau_by_ym=args_by_names['entry_prob_by_tau_by_ym'],
        verbose=args_by_names['verbose'],
        well_outside_option_by_tau_spec=well_outside_option_by_tau_spec,
        tau_multiplier=args_by_names['tau_multiplier'],
        well_target=args_by_names['well_target']
    )
    moments_by_state = pd.DataFrame(output['moments_by_state'])

    # Get the aggregate moments from the simulation
    moments_sim = get_aggregated_moments(
        moments=moments_by_state,
        g_data=args_by_names['g_data'],
        n_rigs=args_by_names['n_rigs']
    )

    # Get the price moments
    coefs_predicted, prices_new_by_spec = prices.get_fast_price_moments(
        params=params,
        price_match_values_by_spec=args_by_names['price_match_values_by_spec'],
        values_by_spec=args_by_names['values_by_spec'],
        df_contracts=args_by_names['df_contracts'],
        non_myopic_dict=args_by_names['non_myopic_dict']
    )

    # Get the extension moments
    mean_reneg = get_extensions_moment(
        eta=params['eta'],
        df_contracts=copy.deepcopy(args_by_names['df_contracts']),
        surplus_grid=args_by_names['surplus_grid'],
        surplus_values_by_tau_spec=surplus_values_by_tau_spec
    )

    # Get the objective function
    objective, deviations = get_objective_function(
        moments_sim=moments_sim,
        moments_data=args_by_names['moments_data'],
        coefs_predicted=coefs_predicted,
        coefs_data=args_by_names['coefs_data'],
        mean_reneg=mean_reneg,
        mean_reneg_data=args_by_names['mean_reneg_data'],
        weights=args_by_names['weights']
    )

    # Print out some results
    print('OBJECTIVE', objective, flush=True)

    if args_by_names['verbose']:
        print(pd.Series(moments_sim))
        print(coefs_predicted)
        print(deviations)
    if np.isnan(objective):
        print(params)
        print(moments_sim)

    if not args_by_names['verbose_output']:
        return objective

    return {
        'objective': objective,
        'cutoffs_min': pd.DataFrame(output['cutoffs_min_by_state']),
        'cutoffs_max': pd.DataFrame(output['cutoffs_max_by_state']),
        'moments_sim': moments_sim,
        'coefs_predicted': coefs_predicted,
        'mean_reneg': mean_reneg,
        'prices_new_by_spec': prices_new_by_spec,
        **output
    }


#%% POST-SIMULATION RESULTS PROCESSING --------------------------------------------------
def get_welfare_from_simulation(state_detail, params, mri_max, n_rigs, state_data):
    """ Get the welfare over time from each simulation.

    For each date and rig spec, the per-rig match value on the complexity
    grid,

        m_2 * g * (rho_0 + rho_1*mri + rho_2*mri^2 + rho_3*mri^3)
        + m_1_spec * mri + m_0_spec,

    is weighted by the state mass in every column of each contract-length
    array and summed. The result is scaled by the number of rigs of that spec.

    Args:
        state_detail (dict): keyed by date; each value is a dict keyed by spec
            ('low', 'mid', 'high') holding three arrays (tau = 2, 3, 4) with
            30 rows matching the complexity grid (presumably
            ``do_simulation``'s 'state_detail_by_ym' - confirm with caller).
        params (dict): reads 'm_2', 'rho_0'..'rho_3', 'm_1_{spec}', 'm_0_{spec}'.
        mri_max (float): upper bound of the complexity grid.
        n_rigs (dict): number of rigs by spec.
        state_data (pd.DataFrame): must have 'date' and 'g' columns, with a
            row for every date in ``state_detail``.

    Returns:
        dict mapping spec -> pd.Series of total value indexed by date.
    """
    mri_state_grid_length = 30
    mri_state_grid = np.linspace(0, mri_max, mri_state_grid_length)

    # Complexity polynomial multiplying the gas price; identical for every date.
    rho_poly = (
        params['rho_0']
        + params['rho_1'] * mri_state_grid
        + params['rho_2'] * mri_state_grid * mri_state_grid
        + params['rho_3'] * mri_state_grid * mri_state_grid * mri_state_grid
    )

    total_value = dict()
    for i in state_detail:
        # The gas-price term depends on the date only, not on the rig spec.
        state = state_data[state_data['date'] == i]
        gas_price_portion = params['m_2'] * state['g'].iloc[0] * rho_poly
        for spec in ['low', 'mid', 'high']:
            value_by_mri = (
                gas_price_portion
                + params[f'm_1_{spec}'] * mri_state_grid
                + params[f'm_0_{spec}']
            )
            for k, tau in enumerate([2, 3, 4]):
                total_value[(spec, i, tau)] = (value_by_mri @ state_detail[i][spec][k]).sum()

    total_value = pd.Series(total_value)

    total_value_by_spec = dict()
    for spec in ['low', 'mid', 'high']:
        total_value_by_spec[spec] = (
            total_value.xs(spec)
            .groupby(level=[0])
            .sum() * n_rigs[spec]
        )

    return total_value_by_spec


def get_welfare_initial_match_only(state_detail, params, mri_max, n_rigs, state_data,
                                   discount=0.99):
    """ Get the welfare over time from each simulation, looking only at the initial match.

    Only the last column of each contract-length array is used. This column
    is the mass of contracts new or extended in that period (the same reading
    ``get_moments`` uses). For a contract of length ``tau``, the per-period
    value is multiplied by ``1 + discount + ... + discount**(tau - 1)``. The
    per-rig match value on the complexity grid is

        m_2 * g * (rho_0 + rho_1*mri + rho_2*mri^2 + rho_3*mri^3)
        + m_1_spec * mri + m_0_spec.

    Args:
        state_detail (dict): keyed by date; each value is a dict keyed by spec
            holding three arrays (tau = 2, 3, 4) with 30 rows matching the
            complexity grid.
        params (dict): reads 'm_2', 'rho_0'..'rho_3', 'm_1_{spec}', 'm_0_{spec}'.
        mri_max (float): upper bound of the complexity grid.
        n_rigs (dict): number of rigs by spec.
        state_data (pd.DataFrame): must have 'date' and 'g' columns, with a
            row for every date in ``state_detail``.
        discount (float): per-period discount factor (default 0.99).

    Returns:
        dict mapping spec -> pd.Series of total value indexed by date.
    """
    mri_state_grid_length = 30
    mri_state_grid = np.linspace(0, mri_max, mri_state_grid_length)

    # Complexity polynomial multiplying the gas price; identical for every date.
    rho_poly = (
        params['rho_0']
        + params['rho_1'] * mri_state_grid
        + params['rho_2'] * mri_state_grid * mri_state_grid
        + params['rho_3'] * mri_state_grid * mri_state_grid * mri_state_grid
    )

    total_value = dict()
    for i in state_detail:
        state = state_data[state_data['date'] == i]
        gas_price_portion = params['m_2'] * state['g'].iloc[0] * rho_poly
        for spec in ['low', 'mid', 'high']:
            value_by_mri = (
                gas_price_portion
                + params[f'm_1_{spec}'] * mri_state_grid
                + params[f'm_0_{spec}']
            )
            for k, tau in enumerate([2, 3, 4]):
                output = (value_by_mri @ state_detail[i][spec][k][:, -1]).sum()
                # Discounted sum over the tau periods of the contract.
                horizon_factor = sum(discount ** j for j in range(tau))
                total_value[(spec, i, tau)] = output * horizon_factor

    total_value = pd.Series(total_value)

    total_value_by_spec = dict()
    for spec in ['low', 'mid', 'high']:
        total_value_by_spec[spec] = (
            total_value.xs(spec)
            .groupby(level=[0])
            .sum() * n_rigs[spec]
        )

    return total_value_by_spec


